import os
from abc import ABC, abstractmethod
from typing import Callable, Iterator

import fitz
from loguru import logger

from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.schemas import PageInfo
from magic_pdf.data.utils import fitz_doc_to_image
from magic_pdf.filter import classify


class PageableData(ABC):
    """Abstract interface of a single page that can be rendered and annotated."""

    @abstractmethod
    def get_image(self) -> dict:
        """Render this page and return the resulting image data."""

    @abstractmethod
    def get_doc(self) -> fitz.Page:
        """Return the underlying PyMuPDF page object."""

    @abstractmethod
    def get_page_info(self) -> PageInfo:
        """Describe this page.

        Returns:
            PageInfo: the page info of this page
        """

    @abstractmethod
    def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
        """Draw a rectangle on the page.

        Args:
            rect_coords (list[float]): the top-left and bottom-right corners of the rectangle, [x0, y0, x1, y1]
            color (list[float] | None): RGB triple of the border line, None means no border line is drawn
            fill (list[float] | None): RGB triple used to fill the rectangle, None means the rectangle is not filled
            fill_opacity (float): opacity of the fill, in the range [0, 1]
            width (float): the width of the border line
            overlay (bool): selects whether the rectangle is drawn in the foreground or the background of the page
        """

    @abstractmethod
    def insert_text(self, coord, content, fontsize, color):
        """Write text onto the page.

        Args:
            coord (list[float]): the position at which the text is placed
            content (str): the text content
            fontsize (int): font size of the text
            color (list[float] | None): RGB triple of the text, None means the default font color is used
        """


class Dataset(ABC):
    """Abstract interface of a collection of pages built from raw bytes."""

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of pages in the dataset."""

    @abstractmethod
    def __iter__(self) -> Iterator[PageableData]:
        """Iterate over the pages of the dataset."""

    @abstractmethod
    def supported_methods(self) -> list[SupportedPdfParseMethod]:
        """List the parse methods this dataset supports.

        Returns:
            list[SupportedPdfParseMethod]: The supported methods, Valid methods are: OCR, TXT
        """

    @abstractmethod
    def data_bits(self) -> bytes:
        """Return the bytes this dataset was created from."""

    @abstractmethod
    def get_page(self, page_id: int) -> PageableData:
        """Return a single page of the dataset.

        Args:
            page_id (int): the index of the page

        Returns:
            PageableData: the page at that index
        """

    @abstractmethod
    def dump_to_file(self, file_path: str):
        """Write the dataset to a file.

        Args:
            file_path (str): the destination file path
        """

    @abstractmethod
    def apply(self, proc: Callable, *args, **kwargs):
        """Run a callable against this dataset.

        Args:
            proc (Callable): invoked as follows:
                proc(self, *args, **kwargs)

        Returns:
            Any: whatever proc returns
        """

    @abstractmethod
    def classify(self) -> SupportedPdfParseMethod:
        """Decide which parse method suits this dataset.

        Returns:
            SupportedPdfParseMethod: the parse method chosen for the dataset
        """

    @abstractmethod
    def clone(self):
        """Create a copy of this dataset."""


class PymuDocDataset(Dataset):
    """Dataset backed by a PDF document opened with PyMuPDF (``fitz``)."""

    def __init__(self, bits: bytes, lang=None):
        """Initialize the dataset, which wraps the PyMuPDF document.

        Args:
            bits (bytes): the bytes of the pdf
            lang: optional language setting remembered by the dataset and
                injected by ``apply`` into calls that pass a ``lang`` keyword.
                An empty string is treated the same as None.
        """
        self._raw_fitz = fitz.open('pdf', bits)
        self._records = [Doc(v) for v in self._raw_fitz]
        # For a PDF the original input and the PDF bytes are the same object.
        self._data_bits = bits
        self._raw_data = bits

        if lang == '':
            self._lang = None
        else:
            self._lang = lang
            logger.info(f"lang: {lang}")

    def __len__(self) -> int:
        """The page number of the pdf."""
        return len(self._records)

    def __iter__(self) -> Iterator[PageableData]:
        """Yield the page doc object."""
        return iter(self._records)

    def supported_methods(self) -> list[SupportedPdfParseMethod]:
        """The method supported by this dataset.

        Returns:
            list[SupportedPdfParseMethod]: the supported methods
        """
        return [SupportedPdfParseMethod.OCR, SupportedPdfParseMethod.TXT]

    def data_bits(self) -> bytes:
        """The pdf bits used to create this dataset."""
        return self._data_bits

    def get_page(self, page_id: int) -> PageableData:
        """The page doc object.

        Args:
            page_id (int): the page doc index

        Returns:
            PageableData: the page doc object
        """
        return self._records[page_id]

    def dump_to_file(self, file_path: str):
        """Save the wrapped document to a file, creating its directory if needed.

        Args:
            file_path (str): the file path
        """
        dir_name = os.path.dirname(file_path)
        # '', '.' and '..' need no creation; anything else is created on demand.
        if dir_name not in ('', '.', '..'):
            os.makedirs(dir_name, exist_ok=True)
        self._raw_fitz.save(file_path)

    def apply(self, proc: Callable, *args, **kwargs):
        """Invoke a callable with this dataset as its first argument.

        When the caller passes a ``lang`` keyword and this dataset was created
        with a language, the dataset's language replaces the passed value.

        Args:
            proc (Callable): invoke proc as follows:
                proc(dataset, *args, **kwargs)

        Returns:
            Any: return the result generated by proc
        """
        if 'lang' in kwargs and self._lang is not None:
            kwargs['lang'] = self._lang
        return proc(self, *args, **kwargs)

    def classify(self) -> SupportedPdfParseMethod:
        """Classify the dataset by running ``classify`` on the pdf bytes.

        Returns:
            SupportedPdfParseMethod: the result returned by ``classify``
        """
        return classify(self._data_bits)

    def clone(self):
        """Clone this dataset, keeping its language setting.

        Returns:
            PymuDocDataset: a new dataset built from the same bytes and lang
        """
        # Forward the language so that apply() on the clone behaves like
        # apply() on the original.
        return PymuDocDataset(self._raw_data, lang=self._lang)


class ImageDataset(Dataset):
    """Dataset backed by a single image that is converted to a PDF."""

    def __init__(self, bits: bytes):
        """Initialize the dataset from image bytes.

        The image is opened with PyMuPDF, converted to PDF bytes, and the
        resulting PDF is opened as the document this dataset wraps.

        Args:
            bits (bytes): the bytes of the photo which will be converted to pdf first. then converted to pymudoc.
        """
        image_doc = fitz.open(stream=bits)
        try:
            pdf_bytes = image_doc.convert_to_pdf()
        finally:
            # The image document is only needed for the conversion.
            image_doc.close()

        self._raw_fitz = fitz.open('pdf', pdf_bytes)
        self._records = [Doc(v) for v in self._raw_fitz]
        self._raw_data = bits  # original image bytes, used by clone()
        self._data_bits = pdf_bytes  # converted PDF bytes, returned by data_bits()

    def __len__(self) -> int:
        """The length of the dataset."""
        return len(self._records)

    def __iter__(self) -> Iterator[PageableData]:
        """Yield the page object."""
        return iter(self._records)

    def supported_methods(self) -> list[SupportedPdfParseMethod]:
        """The method supported by this dataset.

        Returns:
            list[SupportedPdfParseMethod]: the supported methods
        """
        return [SupportedPdfParseMethod.OCR]

    def data_bits(self) -> bytes:
        """The pdf bits used to create this dataset."""
        return self._data_bits

    def get_page(self, page_id: int) -> PageableData:
        """The page doc object.

        Args:
            page_id (int): the page doc index

        Returns:
            PageableData: the page doc object
        """
        return self._records[page_id]

    def dump_to_file(self, file_path: str):
        """Save the converted PDF to a file, creating its directory if needed.

        Args:
            file_path (str): the file path
        """
        dir_name = os.path.dirname(file_path)
        # '', '.' and '..' need no creation; anything else is created on demand.
        if dir_name not in ('', '.', '..'):
            os.makedirs(dir_name, exist_ok=True)
        self._raw_fitz.save(file_path)

    def apply(self, proc: Callable, *args, **kwargs):
        """Invoke a callable with this dataset as its first argument.

        Args:
            proc (Callable): invoke proc as follows:
                proc(dataset, *args, **kwargs)

        Returns:
            Any: return the result generated by proc
        """
        return proc(self, *args, **kwargs)

    def classify(self) -> SupportedPdfParseMethod:
        """Classify the dataset; an image dataset is always OCR.

        Returns:
            SupportedPdfParseMethod: always SupportedPdfParseMethod.OCR
        """
        return SupportedPdfParseMethod.OCR

    def clone(self):
        """Clone this dataset from the original image bytes."""
        return ImageDataset(self._raw_data)


class MultiFileDataset(Dataset):
    """Dataset that merges several PDF and image files into one PDF.

    Every input is opened with PyMuPDF, images are converted to PDF, and the
    pages are appended to one combined document in input order. ``file_info``
    records which page range of the combined document each input occupies.
    Inputs that contain no pages are skipped and get no ``file_info`` entry.
    """

    def __init__(self, file_bytes_list: list[bytes], file_extensions: list[str] = None):
        """Initialize the dataset with multiple files (PDFs and images).

        Args:
            file_bytes_list (list[bytes]): list of file bytes (PDF or image) that will be converted to a multi-page pdf
            file_extensions (list[str], optional): list of file extensions (e.g., ['.pdf', '.jpg', '.png']).
                                                  If provided, will be used to determine file type instead of trying to open files.

        Raises:
            ValueError: if ``file_bytes_list`` is empty, the two lists differ
                in length, an entry has empty bytes, or an extension is not
                supported.
            RuntimeError: if none of the files contributed a page.
        """
        if not file_bytes_list:
            raise ValueError("file_bytes_list cannot be empty")

        if file_extensions and len(file_extensions) != len(file_bytes_list):
            raise ValueError("file_extensions length must match file_bytes_list length")

        # Keep private copies so that later mutation of the caller's lists
        # cannot desynchronise the stored inputs from the recorded file info.
        self._raw_data = list(file_bytes_list)
        self._file_extensions = list(file_extensions) if file_extensions is not None else None

        # Track file information
        self._file_info = []

        # Create a new PDF document that receives the pages of every input.
        pdf_doc = fitz.open()
        try:
            for i, file_bytes in enumerate(self._raw_data):
                if not file_bytes:
                    raise ValueError(f"File at index {i} has empty bytes")

                ext = self._file_extensions[i].lower() if self._file_extensions else None
                temp_doc, file_type = self._open_source(file_bytes, ext)
                try:
                    page_count = len(temp_doc)
                    if page_count == 0:
                        # Nothing to merge; this file gets no file_info entry.
                        continue

                    start_page = len(pdf_doc)

                    # Convert to PDF if needed and insert
                    if file_type == 'image':
                        temp_pdf = fitz.open('pdf', temp_doc.convert_to_pdf())
                        try:
                            pdf_doc.insert_pdf(temp_pdf)
                        finally:
                            temp_pdf.close()
                    else:
                        pdf_doc.insert_pdf(temp_doc)
                finally:
                    # Close the per-file document on success, skip and error alike.
                    temp_doc.close()

                # Record file information
                self._file_info.append({
                    'file_index': i,
                    'file_type': file_type,
                    'page_count': page_count,
                    'start_page': start_page,
                    'end_page': start_page + page_count - 1
                })

            if not self._file_info:
                raise RuntimeError("No valid files were processed")

            # Get the final PDF bytes
            self._data_bits = pdf_doc.tobytes()
        finally:
            pdf_doc.close()

        # Reopen the PDF for processing
        self._raw_fitz = fitz.open('pdf', self._data_bits)
        self._records = [Doc(v) for v in self._raw_fitz]

    @staticmethod
    def _open_source(file_bytes: bytes, ext):
        """Open one input with PyMuPDF and report whether it is a pdf or an image.

        Args:
            file_bytes (bytes): the raw bytes of the file
            ext: lower-cased extension of the file, or None to detect the
                type by trying to open the bytes as a PDF first

        Returns:
            tuple: ``(document, file_type)`` with file_type 'pdf' or 'image'

        Raises:
            ValueError: if ``ext`` is given but not supported
        """
        if ext is None:
            # Best-effort detection: anything that cannot be opened as a PDF
            # is handed to PyMuPDF's own stream type detection.
            try:
                return fitz.open('pdf', file_bytes), 'pdf'
            except Exception:
                return fitz.open(stream=file_bytes), 'image'

        if ext == '.pdf':
            return fitz.open('pdf', file_bytes), 'pdf'
        if ext in ('.jpg', '.jpeg', '.png', '.bmp'):
            return fitz.open(stream=file_bytes), 'image'
        raise ValueError(f"Unsupported file extension: {ext}")

    def _find_file_info(self, file_index: int):
        """Return the file_info entry of an input file, or None if it was skipped.

        Entries are matched on their 'file_index' key rather than on their
        position, because skipped (zero-page) inputs leave gaps.

        Raises:
            IndexError: if ``file_index`` is not a valid input index
        """
        if file_index < 0 or file_index >= len(self._raw_data):
            raise IndexError(f"File index {file_index} out of range")
        for info in self._file_info:
            if info['file_index'] == file_index:
                return info
        return None

    def __len__(self) -> int:
        """The length of the dataset."""
        return len(self._records)

    def __iter__(self) -> Iterator[PageableData]:
        """Yield the page object."""
        return iter(self._records)

    def supported_methods(self) -> list[SupportedPdfParseMethod]:
        """The method supported by this dataset.

        Returns:
            list[SupportedPdfParseMethod]: the supported methods
        """
        return [SupportedPdfParseMethod.OCR]

    def data_bits(self) -> bytes:
        """The pdf bits used to create this dataset."""
        return self._data_bits

    def get_page(self, page_id: int) -> PageableData:
        """The page doc object.

        Args:
            page_id (int): the page doc index

        Returns:
            PageableData: the page doc object
        """
        return self._records[page_id]

    def dump_to_file(self, file_path: str):
        """Save the combined PDF to a file, creating its directory if needed.

        Args:
            file_path (str): the file path
        """
        dir_name = os.path.dirname(file_path)
        # '', '.' and '..' need no creation; anything else is created on demand.
        if dir_name not in ('', '.', '..'):
            os.makedirs(dir_name, exist_ok=True)
        self._raw_fitz.save(file_path)

    def apply(self, proc: Callable, *args, **kwargs):
        """Invoke a callable with this dataset as its first argument.

        Args:
            proc (Callable): invoke proc as follows:
                proc(dataset, *args, **kwargs)

        Returns:
            Any: return the result generated by proc
        """
        return proc(self, *args, **kwargs)

    def classify(self) -> SupportedPdfParseMethod:
        """Classify the dataset; a multi-file dataset is always OCR.

        Returns:
            SupportedPdfParseMethod: always SupportedPdfParseMethod.OCR
        """
        return SupportedPdfParseMethod.OCR

    def clone(self):
        """Clone this dataset from the original file bytes and extensions."""
        return MultiFileDataset(self._raw_data, file_extensions=self._file_extensions)

    @property
    def file_info(self) -> list[dict]:
        """Get information about each file in the dataset.

        Returns:
            list[dict]: List of file information dictionaries containing:
                - file_index: Index of the file in the original input
                - file_type: 'pdf' or 'image'
                - page_count: Number of pages in this file
                - start_page: Starting page index in the combined dataset
                - end_page: Ending page index in the combined dataset
        """
        # Copy each entry as well, so callers cannot alter the internal records.
        return [dict(info) for info in self._file_info]

    def get_file_page_count(self, file_index: int) -> int:
        """Get the page count for a specific file.

        Args:
            file_index (int): Index of the file in the original input

        Returns:
            int: Number of pages in the file; 0 for a file that was skipped
                because it contained no pages

        Raises:
            IndexError: if ``file_index`` is out of range
        """
        info = self._find_file_info(file_index)
        # A file without an entry was skipped precisely because it had 0 pages.
        return 0 if info is None else info['page_count']

    def export_file_as_dataset(self, file_index: int):
        """Export a specific file as an appropriate Dataset.

        Args:
            file_index (int): Index of the file in the original input

        Returns:
            Dataset: ImageDataset for image files, PymuDocDataset for PDF files

        Raises:
            IndexError: if ``file_index`` is out of range, or the file was
                skipped because it contained no pages
        """
        info = self._find_file_info(file_index)
        if info is None:
            raise IndexError(f"File index {file_index} has no pages in the dataset")

        file_bytes = self._raw_data[file_index]
        if info['file_type'] == 'image':
            return ImageDataset(file_bytes)
        return PymuDocDataset(file_bytes)


class Doc(PageableData):
    """Wrapper around a single PyMuPDF page (``fitz.Page``).

    Attributes that are not defined on the wrapper are looked up on the
    wrapped page.
    """

    def __init__(self, doc: fitz.Page):
        """Store the page to wrap.

        Args:
            doc (fitz.Page): the PyMuPDF page object
        """
        self._doc = doc

    def get_image(self) -> dict:
        """Return the image info produced by ``fitz_doc_to_image``.

        Returns:
            dict: {
                img: np.ndarray,
                width: int,
                height: int
            }
        """
        return fitz_doc_to_image(self._doc)

    def get_doc(self) -> fitz.Page:
        """Get the wrapped PyMuPDF page.

        Returns:
            fitz.Page: the PyMuPDF page object
        """
        return self._doc

    def get_page_info(self) -> PageInfo:
        """Get the page info of the page.

        Returns:
            PageInfo: the width and height of the page rectangle
        """
        page_w = self._doc.rect.width
        page_h = self._doc.rect.height
        return PageInfo(w=page_w, h=page_h)

    def __getattr__(self, name):
        """Delegate unknown attributes to the wrapped page.

        Only called when normal lookup on the wrapper fails.

        Raises:
            AttributeError: if the wrapped page has no such attribute either
        """
        # If `_doc` itself is missing (instance not initialised yet), looking
        # it up below would re-enter this method without end.
        if name == '_doc':
            raise AttributeError(name)
        try:
            return getattr(self._doc, name)
        except AttributeError:
            # Raise instead of returning None so hasattr() and getattr() with
            # a default behave as usual and typos do not go unnoticed.
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            ) from None

    def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
        """Draw a rectangle on the page via ``fitz.Page.draw_rect``.

        Args:
            rect_coords (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
            color (list[float] | None): three element tuple which describe the RGB of the border line, None means no border line
            fill (list[float] | None): fill the rectangle with RGB, None means will not fill with color
            fill_opacity (float): opacity of the fill, range from [0, 1]
            width (float): the width of the border line
            overlay (bool): forwarded to PyMuPDF, where True draws in the foreground and False in the background.
        """
        self._doc.draw_rect(
            rect_coords,
            color=color,
            fill=fill,
            fill_opacity=fill_opacity,
            width=width,
            overlay=overlay,
        )

    def insert_text(self, coord, content, fontsize, color):
        """Insert text on the page via ``fitz.Page.insert_text``.

        Args:
            coord (list[float]): forwarded as the first positional argument of ``fitz.Page.insert_text``, which PyMuPDF documents as a point (x, y) where the text starts
            content (str): the text content
            fontsize (int): font size of the text
            color (list[float] | None): three element tuple which describe the RGB of the text, None will use the default font color!
        """
        self._doc.insert_text(coord, content, fontsize=fontsize, color=color)
